{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(tutorial-tensor-expr-get-started)=\n",
        "# 使用张量表达式处理算子\n",
        "\n",
        "**Author**: [Tianqi Chen](https://tqchen.github.io)\n",
        "\n",
        "在本教程中，把注意力转向 TVM 如何使用张量表达式（Tensor Expression，简称 TE）定义张量计算并应用循环优化。TE 以纯函数式语言描述张量计算（即每个表达式都没有副作用）。从 TVM 的整体来看，Relay 将计算描述为一组算子，这些算子都可以表示为 TE 表达式，每个 TE 表达式都接受输入张量并生成输出张量。\n",
        "\n",
        "这是关于 TVM 中张量表达式语言的介绍性教程。TVM 使用领域专用张量表达式来进行有效的内核构建。通过两个使用张量表达式语言的例子，演示基本工作流程。第一个例子介绍了 TE 和用向量加法进行调度。第二个例子扩展了这些概念，用 TE 逐步优化矩阵乘法。这个矩阵乘法的例子将作为未来涵盖 TVM 更高级功能的教程的基础。\n",
        "\n",
        "## 例 1：为 CPU 编写和调度 TE 中的向量加法\n",
        "\n",
        "让我们看看 Python 中的例子，将实现向量加法的 TE，然后是针对 CPU 的调度。首先初始化 TVM 环境。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tvm\n",
        "from tvm import te"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "如果你能确定你所针对的 CPU 并指定它，你将获得更好的性能。如果你使用 LLVM，你可以从命令 ``llc --version`` 中得到这个信息，以获得 CPU 类型，你可以检查 ``/proc/cpuinfo``，了解你的处理器可能支持的额外扩展。例如，你可以使用 ``llvm -mcpu=skylake-avx512`` 来获取带有 AVX-512 指令的 CPU。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "tgt = tvm.target.Target(target=\"llvm\", host=\"llvm\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 描述张量计算\n",
        "\n",
        "描述矢量加法的计算。TVM 采用了张量语义，每个中间结果都表示为一个多维数组。用户需要描述生成张量的计算规则。首先定义符号变量 `n` 来表示形状。然后定义两个占位符张量 `A` 和 `B`，具有给定的形状 `(n,)`。然后用 ``compute`` 算子来描述结果张量 `C`。``compute`` 定义了计算，其输出符合指定的张量形状，计算将在张量中的每个位置进行，由 `lambda` 函数定义。注意，虽然 `n` 是变量，但它定义了 `A`、`B` 和 `C` 张量之间的一致形状。记住，在这个阶段没有实际的计算发生，因为只是声明了计算应该如何进行。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = te.var(\"n\")\n",
        "A = te.placeholder((n,), name=\"A\")\n",
        "B = te.placeholder((n,), name=\"B\")\n",
        "C = te.compute(A.shape, lambda i: A[i] + B[i], name=\"C\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{admonition} Lambda 函数\n",
        "`te.compute` 方法的第二个参数是执行计算的函数。在这个例子中，使用匿名函数（也被称为 `lambda` 函数）来定义计算，在本例中是对 `A` 和 `B` 的第 `i` 个元素进行加法。\n",
        "```\n",
        "\n",
        "### 为计算创建默认的调度\n",
        "\n",
        "虽然上面几行描述了计算规则，但可以用许多不同的方式计算 `C`，以适应不同的设备。对于有多个轴的张量，你可以选择先迭代哪个轴，或者计算可以分成不同的线程。TVM 要求用户提供调度，这是关于计算应该如何进行的描述。TE 中的调度操作可以改变循环顺序，在不同的线程中分割计算，并将数据块分组，以及其他操作。调度背后的重要概念是，它们只描述计算是如何进行的，所以同一个 TE 的不同调度会产生相同的结果。\n",
        "\n",
        "TVM 允许创建自然的调度，通过以行为单位迭代的方式进行 `C` 运算。\n",
        "\n",
        "````{note}\n",
        "类似于 C 语言实现：\n",
        "```c\n",
        "for (int i = 0; i < n; ++i) {\n",
        "  C[i] = A[i] + B[i];\n",
        "}\n",
        "```\n",
        "````"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "s = te.create_schedule(C.op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 为计算创建默认调度\n",
        "\n",
        "有了 TE 表达式和调度，就可以为目标语言和架构（在这里是指 LLVM 和 CPU）生成可运行的代码。向 TVM 提供调度、调度中的 TE 表达式的列表、目标和主机，以及我们要产生的函数的名称。输出的结果是类型消除的（type-erased）函数，可以直接从 Python 中调用。\n",
        "\n",
        "在下面一行，使用 `tvm.build` 来创建函数。`build` 函数需要调度、所需的函数签名（包括输入和输出）以及要编译的目标语言。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fadd = tvm.build(s, [A, B, C], tgt, name=\"myadd\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "运行这个函数，并将其输出与 `numpy` 中的相同计算进行比较。编译后的 TVM 函数暴露了简洁的 C 语言 API，可以从任何语言调用。首先创建设备，也就是 TVM 可以编译调度的设备（本例中为 CPU）。在本例中，该设备是 LLVM CPU 目标。然后可以初始化设备中的张量，并执行自定义的加法运算。为了验证计算是否正确，我们可以将 `c` 张量的输出结果与 `numpy` 进行的相同计算进行比较。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dev = tvm.device(tgt.kind.name, 0)\n",
        "\n",
        "n = 1024\n",
        "a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
        "b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
        "c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
        "fadd(a, b, c)\n",
        "np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "为了得到这个版本与 numpy 相比有多快的比较，创建辅助函数来运行 TVM 生成代码的配置文件。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Numpy running time: 0.000017\n",
            "naive: 0.000015\n"
          ]
        }
      ],
      "source": [
        "import timeit\n",
        "\n",
        "np_repeat = 100\n",
        "np_running_time = timeit.timeit(\n",
        "    setup=\"import numpy\\n\"\n",
        "    \"n = 32768\\n\"\n",
        "    'dtype = \"float32\"\\n'\n",
        "    \"a = numpy.random.rand(n, 1).astype(dtype)\\n\"\n",
        "    \"b = numpy.random.rand(n, 1).astype(dtype)\\n\",\n",
        "    stmt=\"answer = a + b\",\n",
        "    number=np_repeat,\n",
        ")\n",
        "print(\"Numpy running time: %f\" % (np_running_time / np_repeat))\n",
        "\n",
        "\n",
        "def evaluate_addition(func, target, optimization, log):\n",
        "    dev = tvm.device(target.kind.name, 0)\n",
        "    n = 32768\n",
        "    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
        "    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
        "    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
        "\n",
        "    evaluator = func.time_evaluator(func.entry_name, dev, number=10)\n",
        "    mean_time = evaluator(a, b, c).mean\n",
        "    print(\"%s: %f\" % (optimization, mean_time))\n",
        "\n",
        "    log.append((optimization, mean_time))\n",
        "\n",
        "\n",
        "log = [(\"numpy\", np_running_time / np_repeat)]\n",
        "evaluate_addition(fadd, tgt, \"naive\", log=log)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 更新调度以使用并行\n",
        "\n",
        "现在已经说明了 TE 的基本原理，更深入地了解调度的作用，以及如何使用它们来为不同的架构调度张量表达式。调度是一系列应用于表达式的步骤，以多种不同的方式对其进行转换。当调度应用于 TE 中的表达式时，输入和输出保持不变，但在编译时，表达式的实现可以改变。在默认的调度中，这个张量加法是并行运行的，但是很容易在所有的处理器线程中进行并行化。可以将并行调度的操作应用到计算中。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "s[C].parallel(C.op.axis[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`tvm.lower` 命令将生成 TE 的中间表示（IR），以及相应的调度。通过在应用不同的调度操作时降低表达式，可以看到调度对计算的顺序的影响。使用旗标 `simple_mode=True` 来返回一个可读的 C 风格语句。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        n <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        stride <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(A, (n,), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride,), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        stride_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(B, (n,), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride_1,), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        stride_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(C, (n,), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride_2,), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(n):\n",
              "            C_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride_2 <span style=\"color: #AA22FF; font-weight: bold\">*</span> n,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "            A_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride <span style=\"color: #AA22FF; font-weight: bold\">*</span> n,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "            B_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> n,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "            C_2[i <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_2] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A_2[i <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride] <span style=\"color: #AA22FF; font-weight: bold\">+</span> B_2[i <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_1]\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在，TVM 有可能在独立的线程上运行这些块。编译并运行这个应用了并行操作的新调度。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "parallel: 0.000007\n"
          ]
        }
      ],
      "source": [
        "fadd_parallel = tvm.build(s, [A, B, C], tgt, name=\"myadd_parallel\")\n",
        "fadd_parallel(a, b, c)\n",
        "\n",
        "np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())\n",
        "evaluate_addition(fadd_parallel, tgt, \"parallel\", log=log)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 更新调度以使用矢量化\n",
        "\n",
        "现代的 CPU 也有能力对浮点值进行 SIMD 操作，我们可以对我们的计算表达式应用另一个调度，以利用这一优势。实现这一点需要多个步骤：首先，我们必须使用分割调度原语将调度分割成内循环和外循环。内循环可以使用矢量化调度原语来使用 SIMD 指令，然后外循环可以使用并行调度原语来并行化。选择分割因子为你的 CPU 上的线程数。"
      ]
    },
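    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the effect of `split` concrete, the following is a minimal plain-Python sketch of the loop nest that a split with factor 4 describes. It is only an emulation under stated assumptions: `add_split` and its arguments are hypothetical names, not TVM APIs, and the comments mark which loop a schedule would hand to `parallel` and which to `vectorize`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def add_split(a, b, factor=4):\n",
        "    # Plain-Python emulation of the loop nest described by split(factor=4).\n",
        "    n = len(a)\n",
        "    c = np.zeros_like(a)\n",
        "    for i_outer in range((n + factor - 1) // factor):  # candidate for parallel\n",
        "        for i_inner in range(factor):  # candidate for vectorize\n",
        "            i = i_outer * factor + i_inner\n",
        "            if i < n:  # tail guard when n is not a multiple of factor\n",
        "                c[i] = a[i] + b[i]\n",
        "    return c\n",
        "\n",
        "\n",
        "a_np = np.arange(10, dtype='float32')\n",
        "b_np = np.full(10, 2.0, dtype='float32')\n",
        "np.testing.assert_allclose(add_split(a_np, b_np), a_np + b_np)\n",
        "```\n",
        "\n",
        "The guard matters whenever the length is not a multiple of the factor; without it the last outer iteration would index past the end of the arrays."
      ]
    },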
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "vector: 0.000032\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        n <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        stride <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(A, (n,), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride,), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        stride_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(B, (n,), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride_1,), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        stride_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(C, (n,), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride_2,), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel((n <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3</span>) <span style=\"color: #AA22FF; font-weight: bold\">//</span> <span style=\"color: #008000\">4</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> i_inner_s <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">4</span>):\n",
              "                <span style=\"color: #008000; font-weight: bold\">if</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>likely(i_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_inner_s <span style=\"color: #AA22FF; font-weight: bold\">&lt;</span> n):\n",
              "                    C_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride_2 <span style=\"color: #AA22FF; font-weight: bold\">*</span> n,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "                    A_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride <span style=\"color: #AA22FF; font-weight: bold\">*</span> n,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "                    B_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> n,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "                    cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> i_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_inner_s\n",
              "                    C_2[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_2] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        A_2[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride] <span style=\"color: #AA22FF; font-weight: bold\">+</span> B_2[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_1]\n",
              "                    )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# 重新创建调度，因为我们用并行操作修改了它\n",
        "n = te.var(\"n\")\n",
        "A = te.placeholder((n,), name=\"A\")\n",
        "B = te.placeholder((n,), name=\"B\")\n",
        "C = te.compute(A.shape, lambda i: A[i] + B[i], name=\"C\")\n",
        "\n",
        "s = te.create_schedule(C.op)\n",
        "\n",
        "# 这个 factor 应该被选择来匹配适合你的 CPU 的线程数。\n",
        "# 这将根据架构的不同而变化，\n",
        "# 但一个好的规则是将这个 factor 设置为等于可用的 CPU 内核数。\n",
        "factor = 4\n",
        "\n",
        "outer, inner = s[C].split(C.op.axis[0], factor=factor)\n",
        "s[C].parallel(outer)\n",
        "s[C].vectorize(inner)\n",
        "\n",
        "fadd_vector = tvm.build(s, [A, B, C], tgt, name=\"myadd_parallel\")\n",
        "\n",
        "evaluate_addition(fadd_vector, tgt, \"vector\", log=log)\n",
        "\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 比较不同的调度\n",
        "\n",
        "我们现在可以比较不同的调度："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "            Operator\t              Timing\t         Performance\n",
            "               numpy\t1.720716245472431e-05\t                 1.0\n",
            "               naive\t1.4502000000000001e-05\t  0.8427885793580339\n",
            "            parallel\t7.3576000000000005e-06\t 0.42758938432524274\n",
            "              vector\t         3.18174e-05\t  1.8490788404955387\n"
          ]
        }
      ],
      "source": [
        "baseline = log[0][1]\n",
        "print(\"%s\\t%s\\t%s\" % (\"Operator\".rjust(20), \"Timing\".rjust(20), \"Performance\".rjust(20)))\n",
        "for result in log:\n",
        "    print(\n",
        "        \"%s\\t%s\\t%s\"\n",
        "        % (result[0].rjust(20), str(result[1]).rjust(20), str(result[1] / baseline).rjust(20))\n",
        "    )"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{admonition} 代码特殊化\n",
        "正如你可能已经注意到的，`A`、`B` 和 `C` 的声明都采取了相同的形状参数 `n`。TVM 将利用这一点，只向内核传递一个形状参数，正如你在打印的设备代码中发现的那样。这是特殊化的一种形式。\n",
        "\n",
        "在主机端，TVM 会自动生成检查代码，检查参数中的约束。所以如果你把不同形状的数组传入 `fadd`，就会出现错误。\n",
        "\n",
        "我们可以做更多的特殊化。例如，我们可以在计算声明中写 `n = tvm.runtime.convert(1024)`，而不是 `n = te.var(\"n\")` 。生成的函数将只接受长度为 1024 的向量。\n",
        "```\n",
        "\n",
        "我们已经定义、调度并编译了向量加法运算符，然后我们能够在 TVM 运行时上执行它。我们可以将运算符保存为一个库，然后我们可以在以后使用 TVM 运行时加载它。\n",
        "\n",
        "### 针对 GPU 的向量加法（可选）\n",
        "\n",
        "TVM 能够针对多种架构。在下一个例子中，将针对 GPU 的向量加法进行编译。\n",
        "\n",
        "目标更改为 GPU 后端。例如：`cuda` (NVIDIA GPU)、`rocm` (Radeon GPU)、`OpenCL` (OpenCL)。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.te.tensor.Tensor'>\n"
          ]
        }
      ],
      "source": [
        "gpu_tgt = tvm.target.Target(target=\"cuda\", host=\"llvm\")\n",
        "# 重建调度\n",
        "n = te.var(\"n\")\n",
        "A = te.placeholder((n,), name=\"A\")\n",
        "B = te.placeholder((n,), name=\"B\")\n",
        "C = te.compute(A.shape, lambda i: A[i] + B[i], name=\"C\")\n",
        "print(type(C))\n",
        "s = te.create_schedule(C.op)\n",
        "bx, tx = s[C].split(C.op.axis[0], factor=64)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "必须将迭代轴 `bx` 和 `tx` 绑定到 GPU 计算网格中的线程上。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "s[C].bind(bx, te.thread_axis(\"blockIdx.x\"))\n",
        "s[C].bind(tx, te.thread_axis(\"threadIdx.x\"))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "完成调度的指定后，便可将其编译成 TVM 函数。默认情况下，TVM 编译为可以直接从 python 端调用的类型擦除函数。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "fadd = tvm.build(s, [A, B, C], target=gpu_tgt, name=\"myadd\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "编译结果 `fadd` 是 GPU 设备函数（如果涉及GPU）以及调用 GPU 函数的 host 包装器。`fadd` 是生成的主机包装器函数，它在内部包含对生成的设备函数的引用。"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "编译后的 TVM 函数公开了可以从任何语言调用的简洁 C API。在 Python 中提供了最小的数组 API 来帮助快速测试和原型化。数组 API 基于 [DLPack](https://github.com/dmlc/dlpack) 标准。\n",
        "\n",
        "- 首先创建 GPU 设备。\n",
        "- 然后 `tvm.nd.array` 将数据复制到 GPU。\n",
        "- `fadd` 运行实际的计算\n",
        "- `numpy()` 将 GPU 数组复制回 CPU 以验证其正确性。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "dev = tvm.device(gpu_tgt.kind.name, 0)\n",
        "n = 1024\n",
        "a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
        "b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
        "c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
        "fadd(a, b, c)\n",
        "np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 检查生成的 GPU 代码\n",
        "\n",
        "可以在 TVM 中检查生成的代码。`tvm.build` 是 TVM 模块。`fadd` 是包含 host 包装器的 host 模块，它还包含用于 CUDA (GPU) 功能的设备模块。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if (\n",
        "    gpu_tgt.kind.name == \"cuda\"\n",
        "    or tgt.kind.name == \"rocm\"\n",
        "    or tgt.kind.name.startswith(\"opencl\")\n",
        "):\n",
        "    dev_module = fadd.imported_modules[0]\n",
        "    print(\"-----GPU code-----\")\n",
        "    print(dev_module.get_source())\n",
        "else:\n",
        "    print(fadd.get_source())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 保存和加载已编译的模块\n",
        "\n",
        "除了运行时编译，我们还可以将编译后的模块保存到文件中，以后再加载回来。\n",
        "\n",
        "下面的代码首先执行了以下步骤：\n",
        "\n",
        "- 它将编译后的主机模块保存到一个对象文件中。\n",
        "- 然后它将设备模块保存到 ptx 文件中。\n",
        "- `cc.create_shared` 调用编译器（gcc）来创建共享库"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm.contrib import cc\n",
        "from tvm.contrib import utils\n",
        "\n",
        "temp = utils.tempdir()\n",
        "fadd.save(temp.relpath(\"myadd.o\"))\n",
        "if tgt.kind.name == \"cuda\":\n",
        "    fadd.imported_modules[0].save(temp.relpath(\"myadd.ptx\"))\n",
        "if tgt.kind.name == \"rocm\":\n",
        "    fadd.imported_modules[0].save(temp.relpath(\"myadd.hsaco\"))\n",
        "if tgt.kind.name.startswith(\"opencl\"):\n",
        "    fadd.imported_modules[0].save(temp.relpath(\"myadd.cl\"))\n",
        "cc.create_shared(temp.relpath(\"myadd.so\"), [temp.relpath(\"myadd.o\")])\n",
        "print(temp.listdir())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{admonition} 模块存储格式\n",
        "CPU（主机）模块被直接保存为共享库（`.so`）。设备代码可以有多种自定义格式。在我们的例子中，设备代码被保存在 ptx 中，还有一个元数据 json 文件。它们可以通过导入分离加载和链接。\n",
        "```\n",
        "\n",
        "### 加载已编译的模块\n",
        "\n",
        "我们可以从文件系统中加载编译好的模块并运行代码。下面的代码分别加载主机和设备模块，并将它们链接在一起。我们可以验证新加载的功能是否工作。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fadd1 = tvm.runtime.load_module(temp.relpath(\"myadd.so\"))\n",
        "if tgt.kind.name == \"cuda\":\n",
        "    fadd1_dev = tvm.runtime.load_module(temp.relpath(\"myadd.ptx\"))\n",
        "    fadd1.import_module(fadd1_dev)\n",
        "\n",
        "if tgt.kind.name == \"rocm\":\n",
        "    fadd1_dev = tvm.runtime.load_module(temp.relpath(\"myadd.hsaco\"))\n",
        "    fadd1.import_module(fadd1_dev)\n",
        "\n",
        "if tgt.kind.name.startswith(\"opencl\"):\n",
        "    fadd1_dev = tvm.runtime.load_module(temp.relpath(\"myadd.cl\"))\n",
        "    fadd1.import_module(fadd1_dev)\n",
        "\n",
        "fadd1(a, b, c)\n",
        "tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 把所有东西都装进库\n",
        "\n",
        "在上面的例子中，分别存储了设备和主机代码。TVM 也支持将所有东西作为共享库导出。在底层，将设备模块打包成二进制的 blob，并将它们与主机代码连接在一起。目前支持打包 Metal、OpenCL 和 CUDA 模块。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fadd.export_library(temp.relpath(\"myadd_pack.so\"))\n",
        "fadd2 = tvm.runtime.load_module(temp.relpath(\"myadd_pack.so\"))\n",
        "fadd2(a, b, c)\n",
        "tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{admonition} 运行时 API 和线程安全\n",
        "TVM 的编译模块并不依赖于 TVM 编译器。相反，它们只依赖于最小的运行时库。TVM 运行库包装了设备驱动程序，并提供线程安全和设备无关的调用到编译的函数。\n",
        "\n",
        "这意味着你可以从任何线程、任何 GPU 上调用已编译的 TVM 函数，只要你已经为该 GPU 编译了代码。\n",
        "```\n",
        "\n",
        "## 生成 OpenCL 代码\n",
        "\n",
        "TVM 提供代码生成功能到多个后端。我们还可以生成 OpenCL 代码或 LLVM 代码，在 CPU 后端运行。\n",
        "\n",
        "下面的代码块生成 OpenCL 代码，在 OpenCL 设备上创建阵列，并验证代码的正确性。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if tgt.kind.name.startswith(\"opencl\"):\n",
        "    fadd_cl = tvm.build(s, [A, B, C], tgt, name=\"myadd\")\n",
        "    print(\"------opencl code------\")\n",
        "    print(fadd_cl.imported_modules[0].get_source())\n",
        "    dev = tvm.cl(0)\n",
        "    n = 1024\n",
        "    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
        "    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
        "    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
        "    fadd_cl(a, b, c)\n",
        "    tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{admonition} TE 调度原语\n",
        "TVM 包括一些不同的调度原语：\n",
        "\n",
        "- `split`：将指定的轴按定义的因子（factor）分成两个轴。\n",
        "- `tile`：将计算按定义的 factor 分成两个轴。\n",
        "- `fuse`：融合计算的两个连续轴。\n",
        "- `reorder`：可以将计算的轴重新排序到定义的顺序。\n",
        "- `bind`：可以将计算绑定到特定的线程，在 GPU 编程中很有用。\n",
        "- `compute_at`：默认情况下，TVM 会在函数的最外层计算张量，也就是默认的根。`compute_at` 指定一个张量应该在另一个运算符的第一个计算轴上计算。\n",
        "- `compute_inline`：当标记为内联时，计算将被展开，然后插入到需要张量的地址中。\n",
        "- `compute_root`：将计算移到函数的最外层，或根部。这意味着该阶段的计算将在进入下一阶段之前被完全计算。\n",
        "\n",
        "这些原语的完整描述可以在 [调度原语](schedule_primitives) 文档页中找到。\n",
        "```\n",
        "\n",
        "## 实例2：用 TE 手动优化矩阵乘法\n",
        "\n",
        "现在我们将考虑第二个更高级的例子，演示仅用 18 行 python 代码，TVM 如何将普通的矩阵乘法运算加快 18 倍。\n",
        "\n",
        "矩阵乘法是计算密集型运算。为了获得良好的 CPU 性能，有两个重要的优化措施：\n",
        "\n",
        "1. 提高内存访问的高速缓存命中率。复杂的数值计算和热点内存（hot-spot memory）访问都可以通过高缓存命中率（high cache hit rate）来加速。这就要求我们将原点内存（origin ）访问模式转化为符合高速缓存策略的模式。\n",
        "2. SIMD（单指令多数据），也被称为矢量处理单元。在每个周期中，SIMD 可以处理一小批数据，而不是处理一个单一的值。这就要求我们将循环体中的数据访问模式转化为统一模式，以便 LLVM 后端可以将其降低到 SIMD。\n",
        "\n",
        "本教程中使用的技术是 [资源库](https://github.com/flame/how-to-optimize-gemm) 中提到的技巧的一个子集。其中一些已经被 TVM 抽象自动应用了，但由于 TVM 的限制，其中一些不能自动应用。\n",
        "\n",
        "### 准备和性能基线\n",
        "\n",
        "我们首先收集 `numpy` 实现矩阵乘法的性能数据。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Numpy running time: 0.004050\n"
          ]
        }
      ],
      "source": [
        "import tvm\n",
        "from tvm import te\n",
        "import numpy\n",
        "\n",
        "# The size of the matrix\n",
        "# (M, K) x (K, N)\n",
        "# You are free to try out different shapes, sometimes TVM optimization outperforms numpy with MKL.\n",
        "M = 1024\n",
        "K = 1024\n",
        "N = 1024\n",
        "\n",
        "# The default tensor data type in tvm\n",
        "dtype = \"float32\"\n",
        "\n",
        "# You will want to adjust the target to match any CPU vector extensions you\n",
        "# might have. For example, if you're using using Intel AVX2 (Advanced Vector\n",
        "# Extensions) ISA for SIMD, you can get the best performance by changing the\n",
        "# following line to ``llvm -mcpu=core-avx2``, or specific type of CPU you use.\n",
        "# Recall that you're using llvm, you can get this information from the command\n",
        "# ``llc --version`` to get the CPU type, and you can check ``/proc/cpuinfo``\n",
        "# for additional extensions that your processor might support.\n",
        "\n",
        "target = tvm.target.Target(target=\"llvm\", host=\"llvm\")\n",
        "dev = tvm.device(target.kind.name, 0)\n",
        "\n",
        "# Random generated tensor for testing\n",
        "a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)\n",
        "b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)\n",
        "\n",
        "# Repeatedly perform a matrix multiplication to get a performance baseline\n",
        "# for the default numpy implementation\n",
        "np_repeat = 100\n",
        "np_running_time = timeit.timeit(\n",
        "    setup=\"import numpy\\n\"\n",
        "    \"M = \" + str(M) + \"\\n\"\n",
        "    \"K = \" + str(K) + \"\\n\"\n",
        "    \"N = \" + str(N) + \"\\n\"\n",
        "    'dtype = \"float32\"\\n'\n",
        "    \"a = numpy.random.rand(M, K).astype(dtype)\\n\"\n",
        "    \"b = numpy.random.rand(K, N).astype(dtype)\\n\",\n",
        "    stmt=\"answer = numpy.dot(a, b)\",\n",
        "    number=np_repeat,\n",
        ")\n",
        "print(\"Numpy running time: %f\" % (np_running_time / np_repeat))\n",
        "\n",
        "answer = numpy.dot(a.numpy(), b.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在用 TVM TE 编写基本的矩阵乘法，并验证它产生的结果与 `numpy` 的实现相同。我们还写了一个函数，它将帮助衡量调度优化的性能。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "none: 2.171842\n"
          ]
        }
      ],
      "source": [
        "# TVM Matrix Multiplication using TE\n",
        "k = te.reduce_axis((0, K), \"k\")\n",
        "A = te.placeholder((M, K), name=\"A\")\n",
        "B = te.placeholder((K, N), name=\"B\")\n",
        "C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name=\"C\")\n",
        "\n",
        "# Default schedule\n",
        "s = te.create_schedule(C.op)\n",
        "func = tvm.build(s, [A, B, C], target=target, name=\"mmult\")\n",
        "\n",
        "c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)\n",
        "func(a, b, c)\n",
        "np.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)\n",
        "\n",
        "\n",
        "def evaluate_operation(s, vars, target, name, optimization, log):\n",
        "    func = tvm.build(s, [A, B, C], target=target, name=\"mmult\")\n",
        "    assert func\n",
        "\n",
        "    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)\n",
        "    func(a, b, c)\n",
        "    np.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)\n",
        "\n",
        "    evaluator = func.time_evaluator(func.entry_name, dev, number=10)\n",
        "    mean_time = evaluator(a, b, c).mean\n",
        "    print(\"%s: %f\" % (optimization, mean_time))\n",
        "    log.append((optimization, mean_time))\n",
        "\n",
        "\n",
        "log = []\n",
        "evaluate_operation(s, [A, B, C], target=target, name=\"mmult\", optimization=\"none\", log=log)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "让我们来看看使用 TVM 低级函数的运算器和默认调度的中间表示。请注意这个实现基本上是矩阵乘法的天真实现，在 A 和 B 矩阵的索引上使用三个嵌套循环。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x, y <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>):\n",
              "            C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            C_1[x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> k <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">1024</span>):\n",
              "                cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> y\n",
              "                A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_1[cse_var_1] <span style=\"color: #AA22FF; font-weight: bold\">=</span> C_1[cse_var_1] <span style=\"color: #AA22FF; font-weight: bold\">+</span> A_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> k] <span style=\"color: #AA22FF; font-weight: bold\">*</span> B_1[k <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y]\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 优化1：分块\n",
        "\n",
        "提高缓冲区命中率的一个重要技巧是分块，在这个过程中，你的内存访问结构是在一个块的内部有一个小的邻域，具有很高的内存定位性。在本教程中，选择 32 的块因子。这将导致块充满 32 * 32 * sizeof(float) 的内存区域。这相当于 4KB 的缓存大小，而 L1 缓存的参考缓存大小为 32KB。\n",
        "\n",
        "首先为 ``C`` 操作创建默认的调度，然后用指定的块因子对其应用 `tile` 调度原语，调度原语返回所产生的循环顺序，从最外层到最内层，作为向量 `[x_outer, y_outer, x_inner, y_inner]`。然后得到运算输出的归约轴，并使用 4 的因子对其进行 split 操作。这个因子并不直接影响现在正在进行的分块优化，但在以后应用矢量化时将会很有用。\n",
        "\n",
        "现在操作已经被分块了，可以重新调度计算的顺序，把运算的归约轴放到计算的最外层循环中，帮助保证被分块的数据仍然在缓存中。这样就完成了调度，可以建立并测试与原生的调度相比的性能。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "blocking: 0.242442\n"
          ]
        }
      ],
      "source": [
        "bn = 32\n",
        "\n",
        "# Blocking by loop tiling\n",
        "xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)\n",
        "(k,) = s[C].op.reduce_axis\n",
        "ko, ki = s[C].split(k, factor=4)\n",
        "\n",
        "# Hoist reduction domain outside the blocking loop\n",
        "s[C].reorder(xo, yo, ko, ki, xi, yi)\n",
        "\n",
        "evaluate_operation(s, [A, B, C], target=target, name=\"mmult\", optimization=\"blocking\", log=log)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "通过重新安排计算顺序以利用缓存，你应该看到计算的性能有了明显的改善。现在，打印内部表示，并将其与原始表示进行比较。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x_outer, y_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "            C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> x_inner_init, y_inner_init <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "                C_1[\n",
              "                    x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner_init\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> k_outer, k_inner, x_inner, y_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "                cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner\n",
              "                A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_1[cse_var_1] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_1[cse_var_1]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> A_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">*</span> B_1[k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner]\n",
              "                )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 优化 2: 矢量化\n",
        "\n",
        "另一个重要的优化技巧是矢量化。当内存访问模式是统一的，编译器可以检测到这种模式并将连续的内存传递给 SIMD 矢量处理器。在 TVM 中，我们可以使用 ``vectorize`` 接口来提示编译器这种模式，利用这一硬件特性。\n",
        "\n",
        "在本教程中，我们选择对内循环的行数据进行矢量化，因为在我们之前的优化中，它已经是缓存友好的。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "vectorization: 0.248895\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x_outer, y_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "            C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> x_inner_init <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">32</span>):\n",
              "                C_1[\n",
              "                    x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">32</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> k_outer, k_inner, x_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">32</span>):\n",
              "                cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3\n",
              "                A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner], <span style=\"color: #008000\">32</span>)\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">*</span> B_1[\n",
              "                        k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3 : k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                    ]\n",
              "                )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Apply the vectorization optimization\n",
        "s[C].vectorize(yi)\n",
        "\n",
        "evaluate_operation(s, [A, B, C], target=target, name=\"mmult\", optimization=\"vectorization\", log=log)\n",
        "\n",
        "# The generated IR after vectorization\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
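        "### Optimization 3: Loop Permutation\n",
        "\n",
        "If we look at the IR above, we can see that the innermost `yi` loop has been vectorized: each statement now reads and writes 32-wide slices of `B` and `C`, so the traversal of `B` within a row is sequential. Next, look at the access pattern of `A`. In the current schedule `xi` is nested inside `ki`, so consecutive iterations read `A` column by column, 1024 elements apart, which is not cache friendly. If we swap the nesting order of `ki` and the inner axis `xi`, the access pattern of the `A` matrix becomes more cache friendly."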
      ]
    },
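    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of the loop order on `A` can be sketched in plain NumPy, without TVM. The helper `a_offsets` and its parameters are hypothetical names introduced only for this illustration; the extents 32 and 4 mirror the tile size and the split factor used in this example. The sketch lists the flat offsets of `A` that the two scalar inner loops touch under each nesting order.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "N = 1024  # row length of A, as in the tutorial's 1024 x 1024 matrices\n",
        "\n",
        "\n",
        "def a_offsets(xi_inside_ki, xi_extent=32, ki_extent=4):\n",
        "    # Flat offsets of A[xi, ki] in the order the two scalar loops visit them.\n",
        "    if xi_inside_ki:  # for ki: for xi  (before the reorder)\n",
        "        pairs = [(xi, ki) for ki in range(ki_extent) for xi in range(xi_extent)]\n",
        "    else:  # for xi: for ki  (after the reorder)\n",
        "        pairs = [(xi, ki) for xi in range(xi_extent) for ki in range(ki_extent)]\n",
        "    return np.array([xi * N + ki for xi, ki in pairs])\n",
        "\n",
        "\n",
        "before = a_offsets(xi_inside_ki=True)\n",
        "after = a_offsets(xi_inside_ki=False)\n",
        "\n",
        "print(np.diff(before)[:3])  # strides of N: every step lands in a new row of A\n",
        "print(np.diff(after)[:3])   # unit strides: consecutive elements of one row\n",
        "```\n",
        "\n",
        "With `ki` innermost, consecutive reads of `A` are adjacent in memory, which is what the reordered schedule in the next cell aims for."
      ]
    },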
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loop permutation: 0.118851\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x_outer, y_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "            C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> x_inner_init <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">32</span>):\n",
              "                C_1[\n",
              "                    x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">32</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> k_outer, x_inner, k_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">4</span>):\n",
              "                cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3\n",
              "                A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner], <span style=\"color: #008000\">32</span>)\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">*</span> B_1[\n",
              "                        k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3 : k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_3\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                    ]\n",
              "                )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "s = te.create_schedule(C.op)\n",
        "xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)\n",
        "(k,) = s[C].op.reduce_axis\n",
        "ko, ki = s[C].split(k, factor=4)\n",
        "\n",
        "# Re-order the loops so that ki is nested inside xi\n",
        "s[C].reorder(xo, yo, ko, xi, ki, yi)\n",
        "s[C].vectorize(yi)\n",
        "\n",
        "evaluate_operation(\n",
        "    s, [A, B, C], target=target, name=\"mmult\", optimization=\"loop permutation\", log=log\n",
        ")\n",
        "\n",
        "# Again, print the new generated IR\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
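        "### Optimization 4: Array Packing\n",
        "\n",
        "Another important trick is array packing. It reorders the storage dimensions of an array so that an access pattern that is contiguous only along one dimension becomes sequential once the array is flattened.\n",
        "\n",
        "![](images/array-packing.png)\n",
        "\n",
        "As the figure above shows, after blocking the computation the access pattern of B (after flattening) is regular but not contiguous. We would like a transformation that makes it contiguous. By reordering a `[16][16]` array into a `[16/4][16][4]` array, the accesses to B become sequential when the corresponding values are read from the packed array.\n",
        "\n",
        "To accomplish this, we have to start from a new default schedule that takes the new packing of B into account. It is worth pausing on this point: TE is a powerful and expressive language for writing optimized operators, but it often requires some knowledge of the underlying algorithm, data structures, and hardware target you are writing for. Later in the tutorial, we will discuss some options for letting TVM take on this burden."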
      ]
    },
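    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before turning to the TE version, the packed layout can be checked with a small NumPy sketch. This is an illustration under assumed sizes (`K = N = 16`, `bn = 4`, chosen to match the `[16][16]` to `[16/4][16][4]` example), not the code TVM generates. It builds `packedB[x, y, z] = B[y, x * bn + z]` and verifies that one column block of `B` becomes a single contiguous run in the flattened packed array.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "K, N, bn = 16, 16, 4  # assumed small sizes for illustration\n",
        "\n",
        "B = np.arange(K * N, dtype=np.float32).reshape(K, N)\n",
        "\n",
        "# packedB[x, y, z] = B[y, x * bn + z], stored contiguously in that order\n",
        "packedB = np.ascontiguousarray(B.reshape(K, N // bn, bn).transpose(1, 0, 2))\n",
        "\n",
        "# Every element of B is still reachable through the packed indexing\n",
        "k_idx, y_idx = 5, 10\n",
        "assert packedB[y_idx // bn, k_idx, y_idx % bn] == B[k_idx, y_idx]\n",
        "\n",
        "# One column block of B (all k, bn adjacent columns) is now one contiguous run\n",
        "block = 2\n",
        "flat = packedB.reshape(-1)\n",
        "run = flat[block * K * bn : (block + 1) * K * bn]\n",
        "assert np.array_equal(run, B[:, block * bn : (block + 1) * bn].reshape(-1))\n",
        "```\n",
        "\n",
        "The indexing `packedB[y // bn, k, y % bn]` in the sketch is the same mapping that the rewritten `C` computation in the next cell relies on."
      ]
    },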
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "array packing: 0.138342\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        packedB <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">32768</span>], <span style=\"color: #BA2121\">&quot;float32x32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        packedB_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">32768</span>,), <span style=\"color: #BA2121\">&quot;float32x32&quot;</span>, data<span style=\"color: #AA22FF; font-weight: bold\">=</span>packedB)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(<span style=\"color: #008000\">32</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> y <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">1024</span>):\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                packedB_1[x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y] <span style=\"color: #AA22FF; font-weight: bold\">=</span> B_1[\n",
              "                    y <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : y <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                ]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x_outer, y_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "            C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> x_inner_init <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">32</span>):\n",
              "                C_1[\n",
              "                    x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">32</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> k_outer, x_inner, k_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">4</span>):\n",
              "                cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span>\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner], <span style=\"color: #008000\">32</span>)\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_inner]\n",
              "                )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# We have to rewrite the algorithm slightly to use the packed layout of B.\n",
        "packedB = te.compute((N // bn, K, bn), lambda x, y, z: B[y, x * bn + z], name=\"packedB\")\n",
        "C = te.compute(\n",
        "    (M, N),\n",
        "    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),\n",
        "    name=\"C\",\n",
        ")\n",
        "\n",
        "s = te.create_schedule(C.op)\n",
        "\n",
        "xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)\n",
        "(k,) = s[C].op.reduce_axis\n",
        "ko, ki = s[C].split(k, factor=4)\n",
        "\n",
        "s[C].reorder(xo, yo, ko, xi, ki, yi)\n",
        "s[C].vectorize(yi)\n",
        "\n",
        "# Vectorize and parallelize the packing of B itself\n",
        "x, y, z = s[packedB].op.axis\n",
        "s[packedB].vectorize(z)\n",
        "s[packedB].parallel(x)\n",
        "\n",
        "evaluate_operation(s, [A, B, C], target=target, name=\"mmult\", optimization=\"array packing\", log=log)\n",
        "\n",
        "# Here is the generated IR after array packing.\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
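        "### Optimization 5: Optimizing Block Writing Through Caching\n",
        "\n",
        "Up to this point, all of our optimizations have focused on efficiently accessing and computing the data from the `A` and `B` matrices to compute the `C` matrix. After the blocking optimization, the operator writes its results to `C` block by block, and that access pattern is not sequential. We can address this with a sequential cache array: a combination of `cache_write`, `compute_at`, and `unroll` holds the block results in the cache and writes them to `C` once all the results of a block are ready."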
      ]
    },
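    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind the write cache can be sketched in NumPy as a blocked matrix multiplication that accumulates each block in a small contiguous scratch buffer and copies it into `C` only once. The sizes (`M = K = N = 64`, `bn = 16`) and the name `c_block` are assumptions made for this illustration; the scratch buffer only loosely plays the role of the array that `cache_write` creates.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "M = K = N = 64  # assumed small sizes for illustration\n",
        "bn = 16\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "A = rng.random((M, K), dtype=np.float32)\n",
        "B = rng.random((K, N), dtype=np.float32)\n",
        "C = np.empty((M, N), dtype=np.float32)\n",
        "\n",
        "for xo in range(0, M, bn):\n",
        "    for yo in range(0, N, bn):\n",
        "        # Contiguous scratch buffer holding one bn x bn block of results\n",
        "        c_block = np.zeros((bn, bn), dtype=np.float32)\n",
        "        for k in range(K):\n",
        "            c_block += np.outer(A[xo : xo + bn, k], B[k, yo : yo + bn])\n",
        "        # Single write of the finished block into the strided output\n",
        "        C[xo : xo + bn, yo : yo + bn] = c_block\n",
        "\n",
        "assert np.allclose(C, A @ B, rtol=1e-4)\n",
        "```\n",
        "\n",
        "All accumulation happens in the sequential scratch buffer, and `C` is touched once per block, which is the access pattern the schedule in the next cell is meant to produce."
      ]
    },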
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "block caching: 0.120341\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        packedB <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">32768</span>], <span style=\"color: #BA2121\">&quot;float32x32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        C_global <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">1024</span>], <span style=\"color: #BA2121\">&quot;float32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        packedB_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">32768</span>,), <span style=\"color: #BA2121\">&quot;float32x32&quot;</span>, data<span style=\"color: #AA22FF; font-weight: bold\">=</span>packedB)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(<span style=\"color: #008000\">32</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> y <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">1024</span>):\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                packedB_1[x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y] <span style=\"color: #AA22FF; font-weight: bold\">=</span> B_1[\n",
              "                    y <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : y <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                ]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x_outer, y_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "            C_global_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C_global)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> x_c_init <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">32</span>):\n",
              "                C_global_1[x_c_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : x_c_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(\n",
              "                    T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">32</span>\n",
              "                )\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> k_outer, x_c <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">32</span>):\n",
              "                cse_var_4: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span>\n",
              "                cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_c <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_4\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_c <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_4\n",
              "                A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2]\n",
              "                )\n",
              "                C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1</span>], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1</span>]\n",
              "                )\n",
              "                C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">2</span>], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">2</span>]\n",
              "                )\n",
              "                C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3</span>], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3</span>]\n",
              "                )\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> x_inner, y_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "                C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                C_1[\n",
              "                    x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> C_global_1[x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner]\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "s = te.create_schedule(C.op)\n",
        "\n",
        "# Allocate write cache\n",
        "CC = s.cache_write(C, \"global\")\n",
        "\n",
        "xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)\n",
        "\n",
        "# Write cache is computed at yo\n",
        "s[CC].compute_at(s[C], yo)\n",
        "\n",
        "# New inner axes\n",
        "xc, yc = s[CC].op.axis\n",
        "\n",
        "(k,) = s[CC].op.reduce_axis\n",
        "ko, ki = s[CC].split(k, factor=4)\n",
        "s[CC].reorder(ko, xc, ki, yc)\n",
        "s[CC].unroll(ki)\n",
        "s[CC].vectorize(yc)\n",
        "\n",
        "x, y, z = s[packedB].op.axis\n",
        "s[packedB].vectorize(z)\n",
        "s[packedB].parallel(x)\n",
        "\n",
        "evaluate_operation(s, [A, B, C], target=target, name=\"mmult\", optimization=\"block caching\", log=log)\n",
        "\n",
        "# Here is the generated IR after write cache blocking.\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
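        "### Optimization 6: Parallelization\n",
        "\n",
        "So far, our computation has been designed to use only a single core. Nearly all modern processors have multiple cores, and this computation can benefit from running in parallel. The final optimization takes advantage of thread-level parallelism."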
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "parallelization: 0.018223\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        packedB <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">32768</span>], <span style=\"color: #BA2121\">&quot;float32x32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        packedB_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">32768</span>,), <span style=\"color: #BA2121\">&quot;float32x32&quot;</span>, data<span style=\"color: #AA22FF; font-weight: bold\">=</span>packedB)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(<span style=\"color: #008000\">32</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> y <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">1024</span>):\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                packedB_1[x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y] <span style=\"color: #AA22FF; font-weight: bold\">=</span> B_1[\n",
              "                    y <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : y <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>\n",
              "                ]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> x_outer <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(<span style=\"color: #008000\">32</span>):\n",
              "            C_global <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">1024</span>], <span style=\"color: #BA2121\">&quot;float32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> y_outer <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">32</span>):\n",
              "                C_global_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C_global)\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> x_c_init <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">32</span>):\n",
              "                    C_global_1[x_c_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> : x_c_init <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(\n",
              "                        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">32</span>\n",
              "                    )\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> k_outer, x_c <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">32</span>):\n",
              "                    cse_var_4: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span>\n",
              "                    cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_c <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span>\n",
              "                    cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_4\n",
              "                    cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_c <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> cse_var_4\n",
              "                    A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2]\n",
              "                    )\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1</span>], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1</span>]\n",
              "                    )\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">2</span>], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">2</span>]\n",
              "                    )\n",
              "                    C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        C_global_1[cse_var_3 : cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">32</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3</span>], <span style=\"color: #008000\">32</span>) <span style=\"color: #AA22FF; font-weight: bold\">*</span> packedB_1[cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3</span>]\n",
              "                    )\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> x_inner, y_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">32</span>):\n",
              "                    C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                    C_1[\n",
              "                        x_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32768</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner\n",
              "                    ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> C_global_1[x_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">32</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> y_inner]\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Parallelize the outer row-tile loop (xo) of C across CPU threads\n",
        "s[C].parallel(xo)\n",
        "\n",
        "x, y, z = s[packedB].op.axis\n",
        "s[packedB].vectorize(z)\n",
        "s[packedB].parallel(x)\n",
        "\n",
        "evaluate_operation(\n",
        "    s, [A, B, C], target=target, name=\"mmult\", optimization=\"parallelization\", log=log\n",
        ")\n",
        "\n",
        "# Here is the generated IR after parallelization.\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
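        "### Summary of the Matrix Multiplication Example\n",
        "\n",
        "After applying the simple optimizations above, which take only 18 lines of code, our generated code can begin to approach the performance of `numpy` backed by the Math Kernel Library (MKL). Since we have been logging performance along the way, we can compare the results."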
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "            Operator\t              Timing\t         Performance\n",
            "                none\t  2.1718422675999998\t                 1.0\n",
            "            blocking\t 0.24244180110000002\t 0.11162956201598885\n",
            "       vectorization\t 0.24889526580000002\t 0.11460098622863732\n",
            "    loop permutation\t        0.1188508981\t   0.054723540412231\n",
            "       array packing\t        0.1383423394\t 0.06369815223868701\n",
            "       block caching\t 0.12034089109999999\t 0.05540959069416354\n",
            "     parallelization\t        0.0182233174\t 0.00839071864097098\n"
          ]
        }
      ],
      "source": [
        "baseline = log[0][1]\n",
        "print(\"%s\\t%s\\t%s\" % (\"Operator\".rjust(20), \"Timing\".rjust(20), \"Performance\".rjust(20)))\n",
        "for result in log:\n",
        "    print(\n",
        "        \"%s\\t%s\\t%s\"\n",
        "        % (result[0].rjust(20), str(result[1]).rjust(20), str(result[1] / baseline).rjust(20))\n",
        "    )"
      ]
    },
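    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `Performance` column above is each timing divided by the baseline, so smaller values are better. As a minimal sketch, the reciprocal gives a speedup factor that can be easier to read. The `timings` list and the `speedups` helper are hypothetical names introduced here for illustration; the values are copied (rounded) from the table above and will differ on your machine.\n",
        "\n",
        "```python\n",
        "# Timings in seconds, rounded from the table printed above (illustrative only).\n",
        "timings = [\n",
        "    ('none', 2.171842),\n",
        "    ('blocking', 0.242442),\n",
        "    ('vectorization', 0.248895),\n",
        "    ('loop permutation', 0.118851),\n",
        "    ('array packing', 0.138342),\n",
        "    ('block caching', 0.120341),\n",
        "    ('parallelization', 0.018223),\n",
        "]\n",
        "\n",
        "\n",
        "def speedups(entries):\n",
        "    # The first entry is treated as the unoptimized baseline.\n",
        "    baseline = entries[0][1]\n",
        "    # Speedup is baseline time divided by the optimized time (larger is better).\n",
        "    return [(name, baseline / seconds) for name, seconds in entries]\n",
        "\n",
        "\n",
        "for name, factor in speedups(timings):\n",
        "    print(f'{name:>20}  {factor:8.1f}x')\n",
        "```\n",
        "\n",
        "In the notebook itself, the existing `log` list could be passed to `speedups` in place of the hard-coded `timings`. With the numbers above, the final parallelized schedule comes out at roughly 119 times faster than the baseline."
      ]
    },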
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
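        "Note that the outputs on the web page reflect running times on a non-exclusive (shared) Docker container and should be considered unreliable. You are strongly encouraged to run the tutorial yourself to observe the performance gains achieved by TVM, and to work through each example carefully to understand the iterative improvements made to the matrix multiplication operation.\n",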
        "\n",
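        "## Final Notes and Summary\n",
        "\n",
        "As mentioned earlier, applying optimizations with TE and scheduling primitives can require some knowledge of the underlying architecture and algorithms. However, TE was designed to serve as a foundation for more complex algorithms that can search over potential optimizations. With the knowledge from this introduction to TE, we can now begin to explore how TVM automates the schedule optimization process.\n",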
        "\n",
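        "This tutorial provided a walk-through of the TVM Tensor Expression (TE) workflow using vector addition and matrix multiplication examples. The general workflow is:\n",
        "\n",
        "- Describe your computation via a series of operations.\n",
        "- Describe how the computation should be carried out using schedule primitives.\n",
        "- Compile to the desired target function.\n",
        "- Optionally, save the function to be loaded later.\n"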
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('py38': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    },
    "vscode": {
      "interpreter": {
        "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
